{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MhoQ0WE77laV"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-07T23:10:58.102099Z",
     "iopub.status.busy": "2023-11-07T23:10:58.101852Z",
     "iopub.status.idle": "2023-11-07T23:10:58.105808Z",
     "shell.execute_reply": "2023-11-07T23:10:58.105236Z"
    },
    "id": "_ckMIh7O7s6D"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jYysdyb-CaWM"
   },
   "source": [
    "# Distributed input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FbVhjPpzn6BM"
   },
   "source": [
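    "The [tf.distribute](https://tensorflow.google.cn/guide/distributed_training) APIs provide an easy way for users to scale their training from a single machine to multiple machines. When scaling their model, users also have to distribute their input across multiple devices. `tf.distribute` provides APIs that you can use to distribute your input across devices automatically.\n",
    "\n",
    "This guide shows the different ways to create distributed datasets and iterators using `tf.distribute` APIs. It also covers the following topics:\n",
    "\n",
    "- Usage, sharding and batching options when using `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`.\n",
    "- Different ways in which you can iterate over the distributed dataset.\n",
    "- Differences between the `tf.distribute.Strategy.experimental_distribute_dataset`/`tf.distribute.Strategy.distribute_datasets_from_function` APIs and the `tf.data` APIs, as well as any limitations that users may come across in their usage.\n",
    "\n",
    "This guide does not cover using distributed input with the Keras APIs."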
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MM6W__qraV55"
   },
   "source": [
    "## Distributed datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lNy9GxjSlMKQ"
   },
   "source": [
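    "To scale with the `tf.distribute` APIs, use `tf.data.Dataset` to represent your input. `tf.distribute` works efficiently with `tf.data.Dataset`. Examples include automatic prefetching onto each accelerator device and performance improvements that ship regularly. If your use case requires an API other than `tf.data.Dataset`, refer to the [Tensor inputs](#tensorinputs) section of this guide. In a non-distributed training loop, you first create a `tf.data.Dataset` instance and then iterate over its elements. For example:\n"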
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:10:58.109756Z",
     "iopub.status.busy": "2023-11-07T23:10:58.109512Z",
     "iopub.status.idle": "2023-11-07T23:11:00.688043Z",
     "shell.execute_reply": "2023-11-07T23:11:00.687256Z"
    },
    "id": "pCu2Jj-21AEf"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.15.0-rc1\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Helper libraries\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:00.691920Z",
     "iopub.status.busy": "2023-11-07T23:11:00.691466Z",
     "iopub.status.idle": "2023-11-07T23:11:01.251410Z",
     "shell.execute_reply": "2023-11-07T23:11:01.250604Z"
    },
    "id": "6cnilUtmKwpa"
   },
   "outputs": [],
   "source": [
    "# Simulate multiple CPUs with virtual devices\n",
    "N_VIRTUAL_DEVICES = 2\n",
    "physical_devices = tf.config.list_physical_devices(\"CPU\")\n",
    "tf.config.set_logical_device_configuration(\n",
    "    physical_devices[0], [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:01.255836Z",
     "iopub.status.busy": "2023-11-07T23:11:01.255518Z",
     "iopub.status.idle": "2023-11-07T23:11:02.616821Z",
     "shell.execute_reply": "2023-11-07T23:11:02.616047Z"
    },
    "id": "zd4l1ySeLRk1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Available devices:\n",
      "0) LogicalDevice(name='/device:CPU:0', device_type='CPU')\n",
      "1) LogicalDevice(name='/device:CPU:1', device_type='CPU')\n",
      "2) LogicalDevice(name='/device:GPU:0', device_type='GPU')\n",
      "3) LogicalDevice(name='/device:GPU:1', device_type='GPU')\n",
      "4) LogicalDevice(name='/device:GPU:2', device_type='GPU')\n",
      "5) LogicalDevice(name='/device:GPU:3', device_type='GPU')\n"
     ]
    }
   ],
   "source": [
    "print(\"Available devices:\")\n",
    "for i, device in enumerate(tf.config.list_logical_devices()):\n",
    "  print(\"%d) %s\" % (i, device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:02.620893Z",
     "iopub.status.busy": "2023-11-07T23:11:02.620319Z",
     "iopub.status.idle": "2023-11-07T23:11:03.081448Z",
     "shell.execute_reply": "2023-11-07T23:11:03.080671Z"
    },
    "id": "dzLKpmZICaWN"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(16, 1), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "global_batch_size = 16\n",
    "# Create a tf.data.Dataset object.\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
    "\n",
    "@tf.function\n",
    "def train_step(inputs):\n",
    "  features, labels = inputs\n",
    "  return labels - 0.3 * features\n",
    "\n",
    "# Iterate over the dataset using the for..in construct.\n",
    "for inputs in dataset:\n",
    "  print(train_step(inputs))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ihrhYDYRrVLH"
   },
   "source": [
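    "To let users adopt `tf.distribute` strategies with as few changes to their existing code as possible, two APIs were introduced. Each one distributes a `tf.data.Dataset` instance and returns a distributed dataset object. Users can then iterate over this distributed dataset instance and train their model as before. The two APIs, `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function`, are described in more detail below:"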
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4AXoHhrsbdF3"
   },
   "source": [
    "### `tf.distribute.Strategy.experimental_distribute_dataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5mVuLZhbem8d"
   },
   "source": [
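    "#### Usage\n",
    "\n",
    "This API takes a `tf.data.Dataset` instance as input and returns a `tf.distribute.DistributedDataset` instance. Batch the input dataset with a value equal to the global batch size. The global batch size is the number of samples to process across all devices in one step. You can iterate over this distributed dataset in a Pythonic fashion, or you can create an iterator with `iter`. The returned object is not a `tf.data.Dataset` instance. It does not support any other API for transforming or inspecting the dataset. This is the recommended API when you don't need a specific scheme for sharding your input across replicas.\n"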
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:03.085820Z",
     "iopub.status.busy": "2023-11-07T23:11:03.085130Z",
     "iopub.status.idle": "2023-11-07T23:11:04.446050Z",
     "shell.execute_reply": "2023-11-07T23:11:04.445250Z"
    },
    "id": "F2VeZUWUj5S4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "}, PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "global_batch_size = 16\n",
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
    "# Distribute input using the `experimental_distribute_dataset`.\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
    "# 1 global batch of data fed to the model in 1 step.\n",
    "print(next(iter(dist_dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QPceDmRht54F"
   },
   "source": [
    "#### Properties"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0Qb6nDgxiN_n"
   },
   "source": [
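    "##### Batching\n",
    "\n",
    "`tf.distribute` rebatches the input `tf.data.Dataset` instance using a new batch size. That batch size equals the global batch size divided by the number of replicas in sync. The number of replicas in sync is the number of devices that take part in the gradient all-reduce during training. When a user calls `next` on the distributed iterator, each replica receives a per-replica batch of data. The cardinality of the rebatched dataset is always a multiple of the number of replicas. Here are a few examples:\n",
    "\n",
    "- `tf.data.Dataset.range(6).batch(4, drop_remainder=False)`\n",
    "\n",
    "    - Without distribution:\n",
    "        - Batch 1: [0, 1, 2, 3]\n",
    "        - Batch 2: [4, 5]\n",
    "    - With distribution over 2 replicas. The last batch ([4, 5]) is split between the 2 replicas.\n",
    "        - Batch 1:\n",
    "            - Replica 1: [0, 1]\n",
    "            - Replica 2: [2, 3]\n",
    "        - Batch 2:\n",
    "            - Replica 1: [4]\n",
    "            - Replica 2: [5]\n",
    "\n",
    "- `tf.data.Dataset.range(4).batch(4)`\n",
    "\n",
    "    - Without distribution:\n",
    "        - Batch 1: [0, 1, 2, 3]\n",
    "    - With distribution over 5 replicas:\n",
    "        - Batch 1:\n",
    "            - Replica 1: [0]\n",
    "            - Replica 2: [1]\n",
    "            - Replica 3: [2]\n",
    "            - Replica 4: [3]\n",
    "            - Replica 5: []\n",
    "\n",
    "- `tf.data.Dataset.range(8).batch(4)`\n",
    "\n",
    "    - Without distribution:\n",
    "        - Batch 1: [0, 1, 2, 3]\n",
    "        - Batch 2: [4, 5, 6, 7]\n",
    "    - With distribution over 3 replicas:\n",
    "        - Batch 1:\n",
    "            - Replica 1: [0, 1]\n",
    "            - Replica 2: [2, 3]\n",
    "            - Replica 3: []\n",
    "        - Batch 2:\n",
    "            - Replica 1: [4, 5]\n",
    "            - Replica 2: [6, 7]\n",
    "            - Replica 3: []\n",
    "\n",
    "Note: The space complexity of rebatching the dataset grows linearly with the number of replicas. In multi-worker training, this means the input pipeline can run into OOM errors."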
   ]
  },
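  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the following plain-Python sketch reproduces each split in the examples above using one simple rule. Every global batch is cut into consecutive chunks of `ceil(len(batch) / num_replicas)` elements, and any replica beyond the end of the batch receives an empty list. The helpers `batches` and `split_global_batch` are hypothetical. They model only the splits listed above. They are not TensorFlow's rebatching implementation, which may behave differently in other cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helpers (not part of TensorFlow) that model the splits shown above.\n",
    "import math\n",
    "\n",
    "def batches(n, batch_size):\n",
    "  # Plain-list stand-in for tf.data.Dataset.range(n).batch(batch_size).\n",
    "  items = list(range(n))\n",
    "  return [items[i:i + batch_size] for i in range(0, n, batch_size)]\n",
    "\n",
    "def split_global_batch(batch, num_replicas):\n",
    "  # Cut one global batch into consecutive chunks of ceil(len / replicas);\n",
    "  # replicas past the end of the batch receive an empty list.\n",
    "  per_replica = math.ceil(len(batch) / num_replicas)\n",
    "  return [batch[i * per_replica:(i + 1) * per_replica]\n",
    "          for i in range(num_replicas)]\n",
    "\n",
    "for n, num_replicas in [(6, 2), (4, 5), (8, 3)]:\n",
    "  print(f'range({n}).batch(4) over {num_replicas} replicas:')\n",
    "  for step, batch in enumerate(batches(n, 4), start=1):\n",
    "    print(f'  batch {step}:', split_global_batch(batch, num_replicas))"
   ]
  },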
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "IszBuubdtydp"
   },
   "source": [
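    "##### Sharding\n",
    "\n",
    "In multi-worker training with `MultiWorkerMirroredStrategy` and `TPUStrategy`, `tf.distribute` also autoshards the input dataset. Each dataset is created on the worker's CPU device. Autosharding a dataset over a set of workers means each worker is assigned a subset of the entire dataset, provided the right `tf.data.experimental.AutoShardPolicy` is set. This ensures that at each step, each worker processes non-overlapping dataset elements. Autosharding has a few different options, which you can specify with `tf.data.experimental.DistributeOptions`. Multi-worker training with `ParameterServerStrategy` does not autoshard. For details on creating datasets with that strategy, see the [Parameter Server Strategy tutorial](parameter_server_training.ipynb)."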
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:04.450481Z",
     "iopub.status.busy": "2023-11-07T23:11:04.449754Z",
     "iopub.status.idle": "2023-11-07T23:11:04.460497Z",
     "shell.execute_reply": "2023-11-07T23:11:04.459788Z"
    },
    "id": "jwJtsCQhHK-E"
   },
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(16)\n",
    "options = tf.data.Options()\n",
    "options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA\n",
    "dataset = dataset.with_options(options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J7fj3GskHC8g"
   },
   "source": [
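    "You can set three different options for `tf.data.experimental.AutoShardPolicy`:\n",
    "\n",
    "- AUTO: This is the default option. It first attempts to shard by FILE. The attempt fails if no file-based dataset is detected, in which case `tf.distribute` falls back to sharding by DATA. If the input dataset is file-based but has fewer files than there are workers, an error is raised.\n",
    "\n",
    "- FILE: Use this option to shard the input files over all the workers. It is best suited when the number of input files is much larger than the number of workers and the data is evenly distributed across the files. Its downside is that some workers sit idle when the data is unevenly distributed across the files. If there are fewer files than workers, an `InvalidArgumentError` is raised. In that case, explicitly set the policy to `AutoShardPolicy.DATA`. For example, distribute 2 files over 2 workers with 1 replica each. File 1 contains [0, 1, 2, 3, 4, 5] and File 2 contains [6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2 and the global batch size be 4.\n",
    "\n",
    "    - Worker 0:\n",
    "        - Batch 1 = Replica 1: [0, 1]\n",
    "        - Batch 2 = Replica 1: [2, 3]\n",
    "        - Batch 3 = Replica 1: [4]\n",
    "        - Batch 4 = Replica 1: [5]\n",
    "    - Worker 1:\n",
    "        - Batch 1 = Replica 2: [6, 7]\n",
    "        - Batch 2 = Replica 2: [8, 9]\n",
    "        - Batch 3 = Replica 2: [10]\n",
    "        - Batch 4 = Replica 2: [11]\n",
    "\n",
    "- DATA: This autoshards the elements across all the workers. Each worker reads the entire dataset and processes only the shard assigned to it, discarding all other shards. Use it when there are fewer input files than workers and you want better sharding of the data across all workers. Its downside is that every worker reads the entire dataset. For example, distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2.\n",
    "\n",
    "    - Worker 0:\n",
    "        - Batch 1 = Replica 1: [0, 1]\n",
    "        - Batch 2 = Replica 1: [4, 5]\n",
    "        - Batch 3 = Replica 1: [8, 9]\n",
    "    - Worker 1:\n",
    "        - Batch 1 = Replica 2: [2, 3]\n",
    "        - Batch 2 = Replica 2: [6, 7]\n",
    "        - Batch 3 = Replica 2: [10, 11]\n",
    "\n",
    "- OFF: If you turn autosharding off, every worker processes all the data. For example, distribute 1 file over 2 workers. File 1 contains [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]. Let the total number of replicas in sync be 2. The work on each worker is then distributed as follows:\n",
    "\n",
    "    - Worker 0:\n",
    "        - Batch 1 = Replica 1: [0, 1]\n",
    "        - Batch 2 = Replica 1: [2, 3]\n",
    "        - Batch 3 = Replica 1: [4, 5]\n",
    "        - Batch 4 = Replica 1: [6, 7]\n",
    "        - Batch 5 = Replica 1: [8, 9]\n",
    "        - Batch 6 = Replica 1: [10, 11]\n",
    "    - Worker 1:\n",
    "        - Batch 1 = Replica 2: [0, 1]\n",
    "        - Batch 2 = Replica 2: [2, 3]\n",
    "        - Batch 3 = Replica 2: [4, 5]\n",
    "        - Batch 4 = Replica 2: [6, 7]\n",
    "        - Batch 5 = Replica 2: [8, 9]\n",
    "        - Batch 6 = Replica 2: [10, 11]"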
   ]
  },
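  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following plain-Python sketch models which elements each worker receives under the three policies, using the same data as the examples above. The helpers `file_shard`, `data_shard` and `off_shard` are hypothetical. They ignore how each worker's share is further split into per-replica batches. For DATA, they assume, as the example does, that the data has already been grouped into per-replica batches of 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helpers (not part of TensorFlow) that model which elements each\n",
    "# worker reads under the FILE, DATA and OFF policies in the examples above.\n",
    "\n",
    "def file_shard(files, num_workers, worker_index):\n",
    "  # FILE: each worker reads a disjoint subset of whole files.\n",
    "  return [x for f in files[worker_index::num_workers] for x in f]\n",
    "\n",
    "def data_shard(batches, num_workers, worker_index):\n",
    "  # DATA: every worker reads all batches but keeps only every num_workers-th\n",
    "  # one, the same elements Dataset.shard(num_workers, worker_index) keeps.\n",
    "  return batches[worker_index::num_workers]\n",
    "\n",
    "def off_shard(batches, num_workers, worker_index):\n",
    "  # OFF: no sharding, so every worker processes every batch.\n",
    "  return list(batches)\n",
    "\n",
    "files = [list(range(0, 6)), list(range(6, 12))]  # FILE example: 2 files\n",
    "per_replica_batches = [[i, i + 1] for i in range(0, 12, 2)]  # 1 file, batches of 2\n",
    "\n",
    "for worker in range(2):\n",
    "  print('Worker', worker)\n",
    "  print('  FILE:', file_shard(files, 2, worker))\n",
    "  print('  DATA:', data_shard(per_replica_batches, 2, worker))\n",
    "  print('  OFF: ', off_shard(per_replica_batches, 2, worker))"
   ]
  },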
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OK46ZJGPH5H2"
   },
   "source": [
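    "##### Prefetching\n",
    "\n",
    "By default, `tf.distribute` appends a prefetch transformation to the end of the user-provided `tf.data.Dataset` instance. The prefetch transformation's `buffer_size` argument equals the number of replicas in sync."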
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PjiGSY3gtr6_"
   },
   "source": [
    "### `tf.distribute.Strategy.distribute_datasets_from_function`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bAXAo_wWbWSb"
   },
   "source": [
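    "#### Usage\n",
    "\n",
    "This API takes an input function and returns a `tf.distribute.DistributedDataset` instance. The user-supplied input function takes a `tf.distribute.InputContext` argument and should return a `tf.data.Dataset` instance. With this API, `tf.distribute` makes no further changes to the `tf.data.Dataset` instance that the input function returns. Batching and sharding the dataset are the user's responsibility. `tf.distribute` calls the input function on the CPU device of each worker. Beyond letting users specify their own batching and sharding logic, this API also scales better and performs better than `tf.distribute.Strategy.experimental_distribute_dataset` in multi-worker training."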
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:04.463974Z",
     "iopub.status.busy": "2023-11-07T23:11:04.463676Z",
     "iopub.status.idle": "2023-11-07T23:11:04.479809Z",
     "shell.execute_reply": "2023-11-07T23:11:04.479092Z"
    },
    "id": "9ODch-OFCaW4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "def dataset_fn(input_context):\n",
    "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
    "  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64)\n",
    "  dataset = dataset.shard(\n",
    "      input_context.num_input_pipelines, input_context.input_pipeline_id)\n",
    "  dataset = dataset.batch(batch_size)\n",
    "  dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.\n",
    "  return dataset\n",
    "\n",
    "dist_dataset = mirrored_strategy.distribute_datasets_from_function(dataset_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M1bpzPYzt_R7"
   },
   "source": [
    "#### Properties"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7cgzhwiiuBvO"
   },
   "source": [
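    "##### Batching\n",
    "\n",
    "Batch the `tf.data.Dataset` instance that the input function returns using the per-replica batch size. The per-replica batch size is the global batch size divided by the number of replicas taking part in synchronous training. This is required because `tf.distribute` calls the input function on the CPU device of each worker. The dataset created on a given worker must be ready for use by all the replicas on that worker."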
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "e-wlFFZbP33n"
   },
   "source": [
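    "##### Sharding\n",
    "\n",
    "`tf.distribute` creates the `tf.distribute.InputContext` object behind the scenes and passes it implicitly as an argument to the user's input function. It carries information such as the number of workers and the current worker ID. The input function can use these properties of the `tf.distribute.InputContext` object to shard the data according to whatever policy the user chooses.\n"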
   ]
  },
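  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see these sharding properties in practice, you can construct a `tf.distribute.InputContext` by hand, as in the sketch below. The constructor values (two input pipelines, four replicas in sync) are illustrative assumptions. In real training, `tf.distribute` creates this object and passes it to your input function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Build an InputContext by hand to inspect what an input function receives.\n",
    "# The values (2 input pipelines, this is pipeline 1, 4 replicas in sync) are\n",
    "# illustrative assumptions; tf.distribute normally creates this object for you.\n",
    "context = tf.distribute.InputContext(\n",
    "    num_input_pipelines=2, input_pipeline_id=1, num_replicas_in_sync=4)\n",
    "\n",
    "print('pipelines:', context.num_input_pipelines)\n",
    "print('pipeline id:', context.input_pipeline_id)\n",
    "# Global batch size 16 divided by 4 replicas in sync.\n",
    "print('per-replica batch size:', context.get_per_replica_batch_size(16))\n",
    "\n",
    "# Shard by input pipeline, the same call used in dataset_fn above.\n",
    "dataset = tf.data.Dataset.range(8).shard(\n",
    "    context.num_input_pipelines, context.input_pipeline_id)\n",
    "print('elements for this pipeline:', [int(x) for x in dataset.as_numpy_iterator()])"
   ]
  },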
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7TGwnDM-ICHf"
   },
   "source": [
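    "##### Prefetching\n",
    "\n",
    "`tf.distribute` does not append a prefetch transformation to the `tf.data.Dataset` returned by the user-provided input function. That is why the example above calls `Dataset.prefetch` explicitly."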
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iOMsf8kyZZpv"
   },
   "source": [
    "Note: `tf.distribute.Strategy.experimental_distribute_dataset` and `tf.distribute.Strategy.distribute_datasets_from_function` both return **`tf.distribute.DistributedDataset` instances, which are not of type `tf.data.Dataset`**. You can iterate over these instances, as shown in the Distributed iterators section, and use their `element_spec` property."
   ]
  },
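  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the note above, the sketch below reuses the toy dataset from earlier in this guide. The distributed dataset is not a `tf.data.Dataset`, but it still exposes `element_spec`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(16)\n",
    "dist_dataset = strategy.experimental_distribute_dataset(dataset)\n",
    "\n",
    "# A DistributedDataset is not a tf.data.Dataset subclass.\n",
    "print(isinstance(dist_dataset, tf.data.Dataset))\n",
    "# element_spec describes the structure of each distributed element.\n",
    "print(dist_dataset.element_spec)"
   ]
  },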
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dL3XbI1gzEjO"
   },
   "source": [
    "## Distributed iterators"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "w8y54-o9T2Ni"
   },
   "source": [
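    "As with non-distributed `tf.data.Dataset` instances, you need to create an iterator on a `tf.distribute.DistributedDataset` instance to iterate over it and access its elements. The following sections show how to create a `tf.distribute.DistributedIterator` and use it to train your model:\n"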
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FlKh8NV0uOtZ"
   },
   "source": [
    "### Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eSZz6EqOuSlB"
   },
   "source": [
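    "#### Use a Pythonic for loop construct\n",
    "\n",
    "You can iterate over a `tf.distribute.DistributedDataset` with a user-friendly Pythonic loop. Each element returned by the `tf.distribute.DistributedIterator` is either a single `tf.Tensor` or a `tf.distribute.DistributedValues` that holds one value per replica. Placing the loop inside a `tf.function` improves performance. However, `break` and `return` are currently not supported in a loop over a `tf.distribute.DistributedDataset` placed inside a `tf.function`."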
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:04.484073Z",
     "iopub.status.busy": "2023-11-07T23:11:04.483824Z",
     "iopub.status.idle": "2023-11-07T23:11:05.001871Z",
     "shell.execute_reply": "2023-11-07T23:11:05.000921Z"
    },
    "id": "zt3AHb46Tr3w"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
      "  1: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
      "  2: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32),\n",
      "  3: tf.Tensor([[0.7]], shape=(1, 1), dtype=float32)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "global_batch_size = 16\n",
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
    "\n",
    "@tf.function\n",
    "def train_step(inputs):\n",
    "  features, labels = inputs\n",
    "  return labels - 0.3 * features\n",
    "\n",
    "for x in dist_dataset:\n",
    "  # train_step trains the model using the dataset elements\n",
    "  loss = mirrored_strategy.run(train_step, args=(x,))\n",
    "  print(\"Loss is \", loss)"
   ]
  },
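  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The text above recommends placing the loop inside a `tf.function`. The following sketch does that. It iterates over the distributed dataset inside a `tf.function` and accumulates a scalar across steps with `Strategy.reduce`. The names `replica_step` and `sum_over_dataset` are hypothetical. The loop also avoids `break` and `return`, which are not supported there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(16)\n",
    "dist_dataset = strategy.experimental_distribute_dataset(dataset)\n",
    "\n",
    "def replica_step(inputs):\n",
    "  # Same toy computation as train_step above, run on each replica.\n",
    "  features, labels = inputs\n",
    "  return labels - 0.3 * features\n",
    "\n",
    "@tf.function\n",
    "def sum_over_dataset(dist_dataset):\n",
    "  total = tf.constant(0.0)\n",
    "  # The whole loop is traced into one graph; no break/return inside it.\n",
    "  for x in dist_dataset:\n",
    "    per_replica = strategy.run(replica_step, args=(x,))\n",
    "    # Sum the results across replicas, then over all elements of the step.\n",
    "    total += tf.reduce_sum(\n",
    "        strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None))\n",
    "  return total\n",
    "\n",
    "# 100 elements, each 1 - 0.3 * 1, so the result is approximately 70.\n",
    "print(sum_over_dataset(dist_dataset))"
   ]
  },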
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NchPwTEiuSqb"
   },
   "source": [
    "#### Use `iter` to create an explicit iterator\n",
    "\n",
    "To iterate over the elements of a `tf.distribute.DistributedDataset` instance, you can call `iter` on it to create a `tf.distribute.DistributedIterator`. An explicit iterator lets you iterate for a fixed number of steps. To get the next element from a `tf.distribute.DistributedIterator` instance named `dist_iterator`, call `next(dist_iterator)`, `dist_iterator.get_next()`, or `dist_iterator.get_next_as_optional()`."
   ]
  },
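  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you don't know the number of steps in advance, `get_next_as_optional()` lets you stop cleanly at the end of the data without catching an exception. Below is a minimal sketch. It assumes a small 20-element dataset, a hypothetical size chosen so that there are exactly two global batches, one of 16 elements and one of 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "# 20 elements batched by 16 gives two global batches: one of 16 and one of 4.\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(20).batch(16)\n",
    "dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))\n",
    "\n",
    "num_steps = 0\n",
    "while True:\n",
    "  optional = dist_iterator.get_next_as_optional()\n",
    "  if not optional.has_value():  # The iterator is exhausted.\n",
    "    break\n",
    "  features, labels = optional.get_value()\n",
    "  num_steps += 1\n",
    "\n",
    "print('steps:', num_steps)"
   ]
  },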
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:05.006566Z",
     "iopub.status.busy": "2023-11-07T23:11:05.005902Z",
     "iopub.status.idle": "2023-11-07T23:11:08.370061Z",
     "shell.execute_reply": "2023-11-07T23:11:08.369141Z"
    },
    "id": "OrMmakq5EqeQ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n",
      "Loss is  PerReplica:{\n",
      "  0: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  1: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  2: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32),\n",
      "  3: tf.Tensor(\n",
      "[[0.7]\n",
      " [0.7]\n",
      " [0.7]\n",
      " [0.7]], shape=(4, 1), dtype=float32)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 10\n",
    "steps_per_epoch = 5\n",
    "for epoch in range(num_epochs):\n",
    "  dist_iterator = iter(dist_dataset)\n",
    "  for step in range(steps_per_epoch):\n",
    "    # train_step trains the model using the dataset elements\n",
    "    loss = mirrored_strategy.run(train_step, args=(next(dist_iterator),))\n",
    "    # which is the same as\n",
    "    # loss = mirrored_strategy.run(train_step, args=(dist_iterator.get_next(),))\n",
    "    print(\"Loss is \", loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UpJXIlxjqPYg"
   },
   "source": [
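    "When using `next()` or `tf.distribute.DistributedIterator.get_next`, an OutOfRange error is raised once the `tf.distribute.DistributedIterator` has reached its end. The client can catch the error on the Python side and continue with other work, such as checkpointing and evaluation. However, this does not work if you are using a host training loop (that is, running multiple steps per `tf.function`), such as the following:\n",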
    "\n",
    "```\n",
    "@tf.function\n",
    "def train_fn(iterator):\n",
    "  for _ in tf.range(steps_per_loop):\n",
    "    strategy.run(step_fn, args=(next(iterator),))\n",
    "```\n",
    "\n",
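    "`train_fn` contains multiple steps by wrapping the step body inside a `tf.range`. In this case, loop iterations that have no dependencies on each other can start in parallel, so an OutOfRange error can be raised in a later iteration before the computation of earlier iterations has finished. Once an OutOfRange error is thrown, all ops in the function are terminated immediately. If you want to avoid this, use `tf.distribute.DistributedIterator.get_next_as_optional`, which does not raise an OutOfRange error. `get_next_as_optional` returns a `tf.experimental.Optional` that either contains the next element or contains no value if the `tf.distribute.DistributedIterator` has reached its end."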
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:08.374448Z",
     "iopub.status.busy": "2023-11-07T23:11:08.374157Z",
     "iopub.status.idle": "2023-11-07T23:11:09.060161Z",
     "shell.execute_reply": "2023-11-07T23:11:09.059317Z"
    },
    "id": "Iyjao96Vqwyz"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([0], [1], [2], [3])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([4], [5], [6], [7])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "([8], [], [], [])\n"
     ]
    }
   ],
   "source": [
    "# You can break the loop with `get_next_as_optional` by checking if the `Optional` contains a value\n",
    "global_batch_size = 4\n",
    "steps_per_loop = 5\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "dataset = tf.data.Dataset.range(9).batch(global_batch_size)\n",
    "distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))\n",
    "\n",
    "@tf.function\n",
    "def train_fn(distributed_iterator):\n",
    "  for _ in tf.range(steps_per_loop):\n",
    "    optional_data = distributed_iterator.get_next_as_optional()\n",
    "    if not optional_data.has_value():\n",
    "      break\n",
    "    per_replica_results = strategy.run(lambda x: x, args=(optional_data.get_value(),))\n",
    "    tf.print(strategy.experimental_local_results(per_replica_results))\n",
    "train_fn(distributed_iterator)"
   ]
  },
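  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the Python-side approach described earlier (catching the end-of-input error outside of a `tf.function`) might look like the following minimal sketch. It assumes eager execution and reuses the toy `tf.data.Dataset.range(9).batch(4)` input from the previous cell. The identity step function and the `num_steps` counter are illustrative placeholders, not part of the guide's training code.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "# Uses all visible GPUs; falls back to a single CPU replica if none are found.\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "dataset = tf.data.Dataset.range(9).batch(4)\n",
    "dist_iterator = iter(strategy.experimental_distribute_dataset(dataset))\n",
    "\n",
    "num_steps = 0\n",
    "while True:\n",
    "  try:\n",
    "    batch = dist_iterator.get_next()\n",
    "  except tf.errors.OutOfRangeError:\n",
    "    # Input exhausted: a natural point for checkpointing or evaluation.\n",
    "    break\n",
    "  strategy.run(lambda x: x, args=(batch,))  # placeholder for a real train step\n",
    "  num_steps += 1\n",
    "\n",
    "print(num_steps)  # 9 elements in batches of 4, so 3 steps\n",
    "```"
   ]
  },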
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LaclbKnqzLjf"
   },
   "source": [
    "## Using the `element_spec` property"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z1YvXqOpwy08"
   },
   "source": [
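    "If you pass the elements of a distributed dataset to a `tf.function` and want a `tf.TypeSpec` guarantee, you can specify the `input_signature` argument of the `tf.function`. The output of a distributed dataset is `tf.distribute.DistributedValues`, which can represent the input to a single device or to multiple devices. To get the `tf.TypeSpec` corresponding to such a distributed value, use `tf.distribute.DistributedDataset.element_spec` or `tf.distribute.DistributedIterator.element_spec`."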
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:09.064905Z",
     "iopub.status.busy": "2023-11-07T23:11:09.064122Z",
     "iopub.status.idle": "2023-11-07T23:11:11.131319Z",
     "shell.execute_reply": "2023-11-07T23:11:11.130485Z"
    },
    "id": "pg3B-Cw_cn3a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "},\n",
      " PerReplica:{\n",
      "  0: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  1: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  2: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>,\n",
      "  3: <tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
      "array([[1.],\n",
      "       [1.],\n",
      "       [1.],\n",
      "       [1.]], dtype=float32)>\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "global_batch_size = 16\n",
    "epochs = 5\n",
    "steps_per_epoch = 5\n",
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(global_batch_size)\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
    "\n",
    "@tf.function(input_signature=[dist_dataset.element_spec])\n",
    "def train_step(per_replica_inputs):\n",
    "  def step_fn(inputs):\n",
    "    return 2 * inputs\n",
    "\n",
    "  return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))\n",
    "\n",
    "for _ in range(epochs):\n",
    "  iterator = iter(dist_dataset)\n",
    "  for _ in range(steps_per_epoch):\n",
    "    output = train_step(next(iterator))\n",
    "    tf.print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-OAa6svUzuWm"
   },
   "source": [
    "## Data preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pSMrs3kJQexW"
   },
   "source": [
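    "So far, you have learned how to distribute a `tf.data.Dataset`. Before the data is ready for the model, however, it usually needs to be preprocessed, for example by cleansing, transforming, and augmenting it. Two sets of handy tools for this are:\n",
    "\n",
    "- [Keras preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers): a set of Keras layers that let developers build Keras-native input processing pipelines. Some Keras preprocessing layers contain non-trainable state, which can be set at initialization or computed with `adapt` (refer to the `adapt` section of the [Keras preprocessing layers guide](https://tensorflow.google.cn/guide/keras/preprocessing_layers)). When distributing stateful preprocessing layers, the state should be replicated to all workers. To use these layers, you can either make them part of the model or apply them to the datasets.\n",
    "\n",
    "- [TensorFlow Transform (tf.Transform)](https://tensorflow.google.cn/tfx/transform/get_started): a TensorFlow library for defining both instance-level and full-pass data transformations through data preprocessing pipelines. TensorFlow Transform has two phases. The first is the Analyze phase, in which the raw training data is analyzed in a full-pass process to compute the statistics needed for the transformations, and the transformation logic is generated as instance-level operations. The second is the Transform phase, in which the raw training data is transformed in an instance-level process.\n"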
   ]
  },
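  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the `adapt` workflow described above, the snippet below computes the vocabulary of a `tf.keras.layers.StringLookup` layer from a small in-memory tensor and then applies the adapted layer inside a `tf.data` pipeline. The toy data is illustrative and not part of the original guide; the snippet assumes TensorFlow 2.x with `tf.keras` available.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "# Toy string feature (illustrative data only).\n",
    "raw = tf.constant(['a', 'b', 'c', 'a'])\n",
    "\n",
    "# `adapt` computes the layer's non-trainable state (its vocabulary) from data.\n",
    "lookup = tf.keras.layers.StringLookup()\n",
    "lookup.adapt(raw)\n",
    "vocab = lookup.get_vocabulary()\n",
    "\n",
    "# Apply the adapted layer to a batched tf.data.Dataset.\n",
    "batches = tf.data.Dataset.from_tensor_slices(raw).batch(2)\n",
    "ids = [int(i) for batch in batches.map(lambda x: lookup(x)).as_numpy_iterator()\n",
    "       for i in batch]\n",
    "print(vocab, ids)\n",
    "```\n",
    "\n",
    "With the default settings, index 0 is reserved for out-of-vocabulary tokens, and the adapted tokens are ordered by how often they appear in the data."
   ]
  },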
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pd4aUCFdVlZ1"
   },
   "source": [
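    "### Keras preprocessing layers vs. TensorFlow Transform\n",
    "\n",
    "Both TensorFlow Transform and Keras preprocessing layers provide a way to split out preprocessing during training and bundle preprocessing with the model during inference, which reduces train/serve skew.\n",
    "\n",
    "TensorFlow Transform, deeply integrated with [TFX](https://tensorflow.google.cn/tfx), provides a scalable map-reduce solution for analyzing and transforming datasets of any size in a job separate from the training pipeline. If you need to run an analysis on a dataset that cannot fit on a single machine, TensorFlow Transform should be your first choice.\n",
    "\n",
    "Keras preprocessing layers are geared more toward preprocessing applied during training, after the data has been read from disk. They fit seamlessly with model development in the Keras library. They support the analysis of smaller datasets via [`adapt`](https://tensorflow.google.cn/guide/keras/preprocessing_layers#the_adapt_method) and use cases like image data augmentation, where each pass over the input dataset yields different training examples.\n",
    "\n",
    "The two libraries can also be mixed: TensorFlow Transform handles analysis and static transformations of the input data, and Keras preprocessing layers handle train-time transformations (for example, one-hot encoding or data augmentation).\n"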
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MReKhhZpHUpj"
   },
   "source": [
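    "### Best practices with tf.distribute\n",
    "\n",
    "Working with both tools involves initializing the transformation logic to apply to the data, which might create TensorFlow resources. These resources or states should be replicated to all workers to save inter-worker or worker-coordinator communication. To do so, create Keras preprocessing layers, `tft.TFTransformOutput.transform_features_layer`, or `tft.TransformFeaturesLayer` under `tf.distribute.Strategy.scope`, just as you would create any other Keras layer.\n",
    "\n",
    "The following examples demonstrate using the `tf.distribute.Strategy` API with the high-level Keras `Model.fit` API and with a custom training loop, respectively."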
   ]
  },
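  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of this practice, assuming a `tf.distribute.MirroredStrategy` over whatever local devices are available: a `tf.keras.layers.Normalization` layer, whose state (mean and variance) is set at initialization here instead of with `adapt`, is created under `strategy.scope()` together with a small model that uses it as its first layer. The statistics and the model are illustrative only.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "data = tf.constant([[1.0], [2.0], [3.0], [4.0]])  # toy feature column\n",
    "\n",
    "with strategy.scope():\n",
    "  # Create the stateful preprocessing layer under the strategy scope,\n",
    "  # just like any other Keras layer. Mean 2.5 and variance 1.25 match `data`.\n",
    "  norm = tf.keras.layers.Normalization(mean=2.5, variance=1.25)\n",
    "  inputs = tf.keras.Input(shape=(1,))\n",
    "  outputs = tf.keras.layers.Dense(1)(norm(inputs))\n",
    "  model = tf.keras.Model(inputs, outputs)\n",
    "\n",
    "normalized = norm(data)  # (x - mean) / sqrt(variance)\n",
    "predictions = model(data)\n",
    "print(normalized.numpy().ravel())\n",
    "```"
   ]
  },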
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rwEGMWuoX7kJ"
   },
   "source": [
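    "#### Additional notes for Keras preprocessing layer users:\n",
    "\n",
    "**Preprocessing layers and large vocabularies**\n",
    "\n",
    "When dealing with large vocabularies (over one gigabyte) in a multi-worker setting (for example, `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.experimental.ParameterServerStrategy`, `tf.distribute.TPUStrategy`), it is recommended to save the vocabulary to a static file accessible from all workers (for example, with Cloud Storage). This reduces the time spent replicating the vocabulary to all workers during training.\n",
    "\n",
    "**Preprocessing in the `tf.data` pipeline versus in the model**\n",
    "\n",
    "While Keras preprocessing layers can be applied either as part of the model or directly to a `tf.data.Dataset`, each option has its own advantages:\n",
    "\n",
    "- Applying the preprocessing layers within the model makes your model portable, and it helps reduce train/serve skew. (For more details, refer to the *Benefits of doing preprocessing inside the model at inference time* section in the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers#benefits_of_doing_preprocessing_inside_the_model_at_inference_time) guide.)\n",
    "- Applying the preprocessing within the `tf.data` pipeline allows prefetching or offloading to the CPU, which generally gives better performance when using accelerators.\n",
    "\n",
    "When running on one or more TPUs, users should almost always place Keras preprocessing layers in the `tf.data` pipeline, because not all layers support TPUs and string ops do not execute on TPUs. (The two exceptions are `tf.keras.layers.Normalization` and `tf.keras.layers.Rescaling`, which run fine on TPUs and are commonly used as the first layer in an image model.)"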
   ]
  },
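  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two placements described above can be sketched side by side with `tf.keras.layers.Rescaling`. This is an illustration under stated assumptions (toy all-255 images, a hypothetical pooling-only model), not a recommended architecture: option 1 rescales inside the `tf.data` pipeline, and option 2 makes the same preprocessing the first layer of the model.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "# Toy batch of 8 images of shape 4x4x3 with pixel value 255 (illustrative only).\n",
    "images = tf.data.Dataset.from_tensor_slices(tf.fill([8, 4, 4, 3], 255.0)).batch(4)\n",
    "\n",
    "# Option 1: preprocess inside the tf.data pipeline, so it runs on the host\n",
    "# and can overlap with training through prefetching.\n",
    "pipeline_rescale = tf.keras.layers.Rescaling(1.0 / 255)\n",
    "preprocessed = images.map(lambda x: pipeline_rescale(x)).prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "# Option 2: make the preprocessing layer part of the model, so the model\n",
    "# accepts raw pixel values at inference time.\n",
    "inputs = tf.keras.Input(shape=(4, 4, 3))\n",
    "x = tf.keras.layers.Rescaling(1.0 / 255)(inputs)\n",
    "outputs = tf.keras.layers.GlobalAveragePooling2D()(x)\n",
    "model = tf.keras.Model(inputs, outputs)\n",
    "\n",
    "first_batch = next(iter(preprocessed))  # already rescaled by the pipeline\n",
    "model_out = model(next(iter(images)))   # raw batch, rescaled inside the model\n",
    "print(first_batch.shape, model_out.shape)\n",
    "```"
   ]
  },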
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hNCYZ9L-BD2R"
   },
   "source": [
    "### Preprocessing with `Model.fit`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NhRB2Xe8B6bX"
   },
   "source": [
    "When using Keras `Model.fit`, you do not need to distribute data with `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function` yourself. Check out the [Working with preprocessing layers](https://tensorflow.google.cn/guide/keras/preprocessing_layers) guide and the [Distributed training with Keras](https://tensorflow.google.cn/tutorials/distribute/keras) guide for details. A short example might look like this:\n",
    "\n",
    "```\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "with strategy.scope():\n",
    "  # Create the layer(s) under scope.\n",
    "  integer_preprocessing_layer = tf.keras.layers.IntegerLookup(vocabulary=FILE_PATH)\n",
    "  model = ...\n",
    "  model.compile(...)\n",
    "dataset = dataset.map(lambda x, y: (integer_preprocessing_layer(x), y))\n",
    "model.fit(dataset)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3zL2vzJ-G0yg"
   },
   "source": [
    "Users of `tf.distribute.experimental.ParameterServerStrategy` with the `Model.fit` API need to use a `tf.keras.utils.experimental.DatasetCreator` as the input. (Refer to the [Parameter server training](https://tensorflow.google.cn/tutorials/distribute/parameter_server_training#parameter_server_training_with_modelfit_api) guide for details.)\n",
    "\n",
    "```\n",
    "strategy = tf.distribute.experimental.ParameterServerStrategy(\n",
    "    cluster_resolver,\n",
    "    variable_partitioner=variable_partitioner)\n",
    "\n",
    "with strategy.scope():\n",
    "  preprocessing_layer = tf.keras.layers.StringLookup(vocabulary=FILE_PATH)\n",
    "  model = ...\n",
    "  model.compile(...)\n",
    "\n",
    "def dataset_fn(input_context):\n",
    "  ...\n",
    "  dataset = dataset.map(preprocessing_layer)\n",
    "  ...\n",
    "  return dataset\n",
    "\n",
    "dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)\n",
    "model.fit(dataset_creator, epochs=5, steps_per_epoch=20, callbacks=callbacks)\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "imZLQUOYBJyW"
   },
   "source": [
    "### Preprocessing with a custom training loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "r2PX1QH_OwU3"
   },
   "source": [
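    "When writing a [custom training loop](https://tensorflow.google.cn/tutorials/distribute/custom_training), you distribute your data with either the `tf.distribute.Strategy.experimental_distribute_dataset` API or the `tf.distribute.Strategy.distribute_datasets_from_function` API. If you distribute your dataset through `tf.distribute.Strategy.experimental_distribute_dataset`, applying these preprocessing APIs in your data pipeline automatically co-locates the resources with the data pipeline, which avoids remote resource access. The examples here therefore focus on `tf.distribute.Strategy.distribute_datasets_from_function`, where it is crucial to place the initialization of these APIs under `strategy.scope()` for efficiency:"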
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:11.136642Z",
     "iopub.status.busy": "2023-11-07T23:11:11.135877Z",
     "iopub.status.idle": "2023-11-07T23:11:11.682141Z",
     "shell.execute_reply": "2023-11-07T23:11:11.681340Z"
    },
    "id": "wJS1UmcWQeab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PerReplica:{\n",
      "  0: tf.Tensor([1], shape=(1,), dtype=int64),\n",
      "  1: tf.Tensor([3], shape=(1,), dtype=int64),\n",
      "  2: tf.Tensor([0], shape=(1,), dtype=int64),\n",
      "  3: tf.Tensor([1], shape=(1,), dtype=int64)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor([3], shape=(1,), dtype=int64),\n",
      "  1: tf.Tensor([0], shape=(1,), dtype=int64),\n",
      "  2: tf.Tensor([1], shape=(1,), dtype=int64),\n",
      "  3: tf.Tensor([3], shape=(1,), dtype=int64)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor([0], shape=(1,), dtype=int64),\n",
      "  1: tf.Tensor([1], shape=(1,), dtype=int64),\n",
      "  2: tf.Tensor([3], shape=(1,), dtype=int64),\n",
      "  3: tf.Tensor([0], shape=(1,), dtype=int64)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "strategy = tf.distribute.MirroredStrategy()\n",
    "vocab = [\"a\", \"b\", \"c\", \"d\", \"f\"]\n",
    "\n",
    "with strategy.scope():\n",
    "  # Create the layer(s) under scope.\n",
    "  layer = tf.keras.layers.StringLookup(vocabulary=vocab)\n",
    "\n",
    "def dataset_fn(input_context):\n",
    "  # a tf.data.Dataset\n",
    "  dataset = tf.data.Dataset.from_tensor_slices([\"a\", \"c\", \"e\"]).repeat()\n",
    "\n",
    "  # Customize your batching, sharding, prefetching, etc.\n",
    "  global_batch_size = 4\n",
    "  batch_size = input_context.get_per_replica_batch_size(global_batch_size)\n",
    "  dataset = dataset.batch(batch_size)\n",
    "  dataset = dataset.shard(\n",
    "      input_context.num_input_pipelines,\n",
    "      input_context.input_pipeline_id)\n",
    "\n",
    "  # Apply the preprocessing layer(s) to the tf.data.Dataset\n",
    "  def preprocess_with_kpl(input):\n",
    "    return layer(input)\n",
    "\n",
    "  processed_ds = dataset.map(preprocess_with_kpl)\n",
    "  return processed_ds\n",
    "\n",
    "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n",
    "\n",
    "# Print out a few example batches.\n",
    "distributed_dataset_iterator = iter(distributed_dataset)\n",
    "for _ in range(3):\n",
    "  print(next(distributed_dataset_iterator))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PVl1cblWQy8b"
   },
   "source": [
    "Note that if you are training with `tf.distribute.experimental.ParameterServerStrategy`, you will also call `tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset`:\n",
    "\n",
    "```\n",
    "@tf.function\n",
    "def per_worker_dataset_fn():\n",
    "  return strategy.distribute_datasets_from_function(dataset_fn)\n",
    "\n",
    "per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)\n",
    "per_worker_iterator = iter(per_worker_dataset)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ol7SmPID1dAt"
   },
   "source": [
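    "For TensorFlow Transform, as mentioned above, the Analyze stage is done separately from training and is therefore omitted here. Refer to the [tutorial](https://tensorflow.google.cn/tfx/tutorials/transform/census) for a detailed how-to. Usually, this stage includes creating a `tf.Transform` preprocessing function and transforming the data in an [Apache Beam](https://beam.apache.org/) pipeline with this preprocessing function. At the end of the Analyze stage, the output can be exported as a TensorFlow graph that you can use for both training and serving. This example covers only the training pipeline part:\n",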
    "\n",
    "```\n",
    "with strategy.scope():\n",
    "  # working_dir contains the tf.Transform output.\n",
    "  tf_transform_output = tft.TFTransformOutput(working_dir)\n",
    "  # Loading from working_dir to create a Keras layer for applying the tf.Transform output to data\n",
    "  tft_layer = tf_transform_output.transform_features_layer()\n",
    "  ...\n",
    "\n",
    "def dataset_fn(input_context):\n",
    "  ...\n",
    "  dataset = dataset.map(tft_layer, num_parallel_calls=tf.data.AUTOTUNE)\n",
    "  ...\n",
    "  return dataset\n",
    "\n",
    "distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3_IQxRXxQWof"
   },
   "source": [
    "## Partial batches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hW2_gVkiztUG"
   },
   "source": [
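    "Partial batches are encountered when: 1) the `tf.data.Dataset` instances that users create contain batch sizes that are not evenly divisible by the number of replicas; or 2) the cardinality of the dataset instance is not divisible by the batch size. This means that when the dataset is distributed over multiple replicas, the `next` call on some iterators will result in a `tf.errors.OutOfRangeError`. To handle this case, `tf.distribute` returns dummy batches of batch size `0` on replicas that do not have any more data to process.\n"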
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rqutdpqtPcCH"
   },
   "source": [
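    "In the single-worker case, if data is not returned by the `next` call on the iterator, dummy batches of batch size 0 are created and used along with the real data in the dataset. In the case of partial batches, the last global batch of data contains real data alongside dummy batches of data. The stopping condition for processing data now checks whether any of the replicas has data. If there is no data on any of the replicas, a `tf.errors.OutOfRangeError` is raised.\n",
    "\n",
    "In the multi-worker case, the boolean value representing the presence of data on each of the workers is aggregated using cross-replica communication, and this is used to identify whether all the workers have finished processing the distributed dataset. Since this involves cross-worker communication, it incurs some performance penalty.\n"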
   ]
  },
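  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where a partial batch comes from, the sketch below batches a 10-element dataset with a global batch size of 4, which leaves a final partial batch of 2 elements. It also shows `drop_remainder=True` in `tf.data.Dataset.batch`, a standard `tf.data` option (not the `tf.distribute` dummy-batch mechanism described above) that avoids partial batches by discarding the leftover elements. Whether dropping data is acceptable depends on your use case; the sizes here are illustrative.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "global_batch_size = 4\n",
    "num_elements = 10  # not divisible by global_batch_size\n",
    "\n",
    "# The last batch only holds 10 % 4 == 2 elements: a partial batch.\n",
    "ds = tf.data.Dataset.range(num_elements).batch(global_batch_size)\n",
    "sizes = [int(batch.shape[0]) for batch in ds]\n",
    "\n",
    "# With drop_remainder=True the partial batch is discarded, so every batch is\n",
    "# full and the batch dimension is statically known.\n",
    "ds_full = tf.data.Dataset.range(num_elements).batch(global_batch_size, drop_remainder=True)\n",
    "full_sizes = [int(batch.shape[0]) for batch in ds_full]\n",
    "print(sizes, full_sizes)\n",
    "```"
   ]
  },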
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vehLsgljz90Y"
   },
   "source": [
    "## Caveats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Nx4jyN_Az-Dy"
   },
   "source": [
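    "- When using the `tf.distribute.Strategy.experimental_distribute_dataset` API with a multi-worker setup, you pass a `tf.data.Dataset` that reads from files. If `tf.data.experimental.AutoShardPolicy` is set to `AUTO` or `FILE`, the actual per-step batch size may be smaller than the global batch size you defined. This can happen when the remaining elements in a file are fewer than the global batch size. You can work around this either by exhausting the dataset without depending on the number of steps to run, or by setting `tf.data.experimental.AutoShardPolicy` to `DATA`.\n",
    "\n",
    "- Stateful dataset transformations are currently not supported with `tf.distribute`, and any stateful ops that the dataset may have are currently ignored. For example, if your dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, then you have a dataset graph that depends on state (that is, the random seed) on the local machine where the Python process is being executed.\n",
    "\n",
    "- Experimental `tf.data.experimental.OptimizationOptions` that are disabled by default can, in certain contexts (such as when used together with `tf.distribute`), cause a performance degradation. Only enable them after you validate that they benefit the performance of your workload in a distributed setting.\n",
    "\n",
    "- Refer to [this guide](https://tensorflow.google.cn/guide/data_performance) for how to optimize your input pipeline with `tf.data`. A few additional tips:\n",
    "\n",
    "    - If you have multiple workers and are using `tf.data.Dataset.list_files` to create a dataset from all files matching one or more glob patterns, remember to set the `seed` argument or set `shuffle=False` so that each worker shards the files consistently.\n",
    "\n",
    "- If your input pipeline includes both shuffling the data at the record level and parsing the data, shuffle first and then parse, as shown in the example below, unless the unparsed data is significantly larger than the parsed data (which is usually not the case). This can benefit memory usage and performance.\n",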
    "\n",
    "```\n",
    "d = tf.data.Dataset.list_files(pattern, shuffle=False)\n",
    "d = d.shard(num_workers, worker_index)\n",
    "d = d.repeat(num_epochs)\n",
    "d = d.shuffle(shuffle_buffer_size)\n",
    "d = d.interleave(tf.data.TFRecordDataset,\n",
    "                 cycle_length=num_readers, block_length=1)\n",
    "d = d.map(parser_fn, num_parallel_calls=num_map_threads)\n",
    "```\n",
    "\n",
    "- `tf.data.Dataset.shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)` maintains an internal buffer of `buffer_size` elements, so reducing `buffer_size` can alleviate OOM issues."
   ]
  },
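  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the first caveat above, the `DATA` workaround can be sketched with `tf.data.Options`. In this minimal sketch the file-based dataset is replaced by an in-memory `tf.data.Dataset.range` so that it runs anywhere; the policy only has an effect once the dataset is distributed with a multi-worker strategy.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "# Stand-in for a dataset that would normally read from files.\n",
    "dataset = tf.data.Dataset.range(64).batch(16)\n",
    "\n",
    "# Shard by elements rather than by files: each worker reads the whole\n",
    "# dataset and discards the portion that does not belong to it.\n",
    "options = tf.data.Options()\n",
    "options.experimental_distribute.auto_shard_policy = (\n",
    "    tf.data.experimental.AutoShardPolicy.DATA)\n",
    "dataset = dataset.with_options(options)\n",
    "print(dataset.options().experimental_distribute.auto_shard_policy)\n",
    "```\n",
    "\n",
    "The trade-off is extra reading work, since every worker processes the full input before discarding the elements that are not its own."
   ]
  },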
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dAC_vRmJyzrB"
   },
   "source": [
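    "- The order in which the data is processed by the workers when using `tf.distribute.Strategy.experimental_distribute_dataset` or `tf.distribute.Strategy.distribute_datasets_from_function` is not guaranteed. Ordered output is typically required if you are using `tf.distribute` to scale prediction. You can, however, insert an index for each element in the batch and order the outputs accordingly. The following snippet is an example of how to order outputs.\n",
    "\n",
    "Note: `tf.distribute.MirroredStrategy` is used here for the sake of convenience. Reordering inputs is only necessary when you use multiple workers; `tf.distribute.MirroredStrategy` distributes training on a single worker, so it is used here only to demonstrate the technique."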
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:11.687311Z",
     "iopub.status.busy": "2023-11-07T23:11:11.686610Z",
     "iopub.status.idle": "2023-11-07T23:11:12.071571Z",
     "shell.execute_reply": "2023-11-07T23:11:12.070807Z"
    },
    "id": "Zr2xAy-uZZaL"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 0, 1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 12, 7: 14, 8: 16, 9: 18, 10: 20, 11: 22, 12: 24, 13: 26, 14: 28, 15: 30, 16: 32, 17: 34, 18: 36, 19: 38, 20: 40, 21: 42, 22: 44, 23: 46}\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "dataset_size = 24\n",
    "batch_size = 6\n",
    "dataset = tf.data.Dataset.range(dataset_size).enumerate().batch(batch_size)\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
    "\n",
    "def predict(index, inputs):\n",
    "  outputs = 2 * inputs\n",
    "  return index, outputs\n",
    "\n",
    "result = {}\n",
    "for index, inputs in dist_dataset:\n",
    "  output_index, outputs = mirrored_strategy.run(predict, args=(index, inputs))\n",
    "  indices = list(mirrored_strategy.experimental_local_results(output_index))\n",
    "  rindices = []\n",
    "  for a in indices:\n",
    "    rindices.extend(a.numpy())\n",
    "  outputs = list(mirrored_strategy.experimental_local_results(outputs))\n",
    "  routputs = []\n",
    "  for a in outputs:\n",
    "    routputs.extend(a.numpy())\n",
    "  for i, value in zip(rindices, routputs):\n",
    "    result[i] = value\n",
    "\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nNbn7HXx0YqB"
   },
   "source": [
    "<a name=\"tensorinputs\"></a>\n",
    "## Tensor inputs instead of tf.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dymZixqo0nKK"
   },
   "source": [
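    "Sometimes users cannot use a `tf.data.Dataset` to represent their input, and therefore cannot use the APIs mentioned above to distribute the dataset to multiple devices. In such cases, you can use raw tensors or inputs from a generator.\n",
    "\n",
    "### Use experimental_distribute_values_from_function for arbitrary tensor inputs\n",
    "\n",
    "`strategy.run` accepts `tf.distribute.DistributedValues`, which is the output of `next(iterator)`. To pass tensor values, use `tf.distribute.Strategy.experimental_distribute_values_from_function` to construct `tf.distribute.DistributedValues` from raw tensors. With this option, you have to specify your own batching and sharding logic in the input function, which you can do using the `tf.distribute.experimental.ValueContext` input object."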
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:12.075564Z",
     "iopub.status.busy": "2023-11-07T23:11:12.074987Z",
     "iopub.status.idle": "2023-11-07T23:11:12.096076Z",
     "shell.execute_reply": "2023-11-07T23:11:12.095349Z"
    },
    "id": "ajZHNRQs0kqm"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PerReplica:{\n",
      "  0: tf.Tensor(0, shape=(), dtype=int32),\n",
      "  1: tf.Tensor(1, shape=(), dtype=int32),\n",
      "  2: tf.Tensor(2, shape=(), dtype=int32),\n",
      "  3: tf.Tensor(3, shape=(), dtype=int32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor(0, shape=(), dtype=int32),\n",
      "  1: tf.Tensor(1, shape=(), dtype=int32),\n",
      "  2: tf.Tensor(2, shape=(), dtype=int32),\n",
      "  3: tf.Tensor(3, shape=(), dtype=int32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor(0, shape=(), dtype=int32),\n",
      "  1: tf.Tensor(1, shape=(), dtype=int32),\n",
      "  2: tf.Tensor(2, shape=(), dtype=int32),\n",
      "  3: tf.Tensor(3, shape=(), dtype=int32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor(0, shape=(), dtype=int32),\n",
      "  1: tf.Tensor(1, shape=(), dtype=int32),\n",
      "  2: tf.Tensor(2, shape=(), dtype=int32),\n",
      "  3: tf.Tensor(3, shape=(), dtype=int32)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "def value_fn(ctx):\n",
    "  return tf.constant(ctx.replica_id_in_sync_group)\n",
    "\n",
    "distributed_values = mirrored_strategy.experimental_distribute_values_from_function(value_fn)\n",
    "for _ in range(4):\n",
    "  result = mirrored_strategy.run(lambda x: x, args=(distributed_values,))\n",
    "  print(result)"
   ]
  },
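  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example above only returns each replica's ID. As a minimal sketch of the batching and sharding logic mentioned above, the following cell assumes a hypothetical global batch stored in a single NumPy array (`global_batch` and `GLOBAL_BATCH_SIZE` are illustrative names, not part of the `tf.distribute` API). `sharded_value_fn` uses the `num_replicas_in_sync` and `replica_id_in_sync_group` attributes of the `tf.distribute.experimental.ValueContext` object to give each replica its own contiguous slice of that batch.\n",
    "\n",
    "This sketch assumes `GLOBAL_BATCH_SIZE` is divisible by the number of replicas; otherwise, the remainder is silently dropped. For real workloads, the warning printed earlier suggests wrapping `run` inside a `tf.function` for better performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "# Hypothetical raw input: one global batch held in a NumPy array.\n",
    "GLOBAL_BATCH_SIZE = 8\n",
    "global_batch = np.arange(GLOBAL_BATCH_SIZE, dtype=np.float32)\n",
    "\n",
    "def sharded_value_fn(ctx):\n",
    "  # Assumes GLOBAL_BATCH_SIZE is divisible by the number of replicas.\n",
    "  per_replica_batch_size = GLOBAL_BATCH_SIZE // ctx.num_replicas_in_sync\n",
    "  # Each replica takes its own contiguous slice of the global batch.\n",
    "  start = ctx.replica_id_in_sync_group * per_replica_batch_size\n",
    "  return tf.constant(global_batch[start:start + per_replica_batch_size])\n",
    "\n",
    "distributed_values = mirrored_strategy.experimental_distribute_values_from_function(\n",
    "    sharded_value_fn)\n",
    "# Each replica doubles the values in its own shard.\n",
    "result = mirrored_strategy.run(lambda x: x * 2.0, args=(distributed_values,))\n",
    "print(result)"
   ]
  },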
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P98aFQGf0x_7"
   },
   "source": [
    "### 如果您的输入来自生成器，则使用 tf.data.Dataset.from_generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "emZCWQSi04qT"
   },
   "source": [
    "如果您具有要使用的生成器函数，则可以使用 `from_generator` API 创建一个 `tf.data.Dataset` 实例。\n",
    "\n",
    "注：`tf.distribute.TPUStrategy` 当前不支持此功能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-07T23:11:12.099842Z",
     "iopub.status.busy": "2023-11-07T23:11:12.099535Z",
     "iopub.status.idle": "2023-11-07T23:11:12.504139Z",
     "shell.execute_reply": "2023-11-07T23:11:12.503218Z"
    },
    "id": "jRhU0X230787"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PerReplica:{\n",
      "  0: tf.Tensor([0.26498944], shape=(1,), dtype=float32),\n",
      "  1: tf.Tensor([0.9832243], shape=(1,), dtype=float32),\n",
      "  2: tf.Tensor([0.7569181], shape=(1,), dtype=float32),\n",
      "  3: tf.Tensor([0.5905416], shape=(1,), dtype=float32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor([0.2529385], shape=(1,), dtype=float32),\n",
      "  1: tf.Tensor([0.75223196], shape=(1,), dtype=float32),\n",
      "  2: tf.Tensor([0.8507075], shape=(1,), dtype=float32),\n",
      "  3: tf.Tensor([0.35577485], shape=(1,), dtype=float32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor([0.47461054], shape=(1,), dtype=float32),\n",
      "  1: tf.Tensor([0.46633008], shape=(1,), dtype=float32),\n",
      "  2: tf.Tensor([0.2187182], shape=(1,), dtype=float32),\n",
      "  3: tf.Tensor([0.8489092], shape=(1,), dtype=float32)\n",
      "}\n",
      "PerReplica:{\n",
      "  0: tf.Tensor([0.27852485], shape=(1,), dtype=float32),\n",
      "  1: tf.Tensor([0.10208022], shape=(1,), dtype=float32),\n",
      "  2: tf.Tensor([0.5859448], shape=(1,), dtype=float32),\n",
      "  3: tf.Tensor([0.4391938], shape=(1,), dtype=float32)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
    "def input_gen():\n",
    "  while True:\n",
    "    yield np.random.rand(4)\n",
    "\n",
    "# use Dataset.from_generator\n",
    "dataset = tf.data.Dataset.from_generator(\n",
    "    input_gen, output_types=(tf.float32), output_shapes=tf.TensorShape([4]))\n",
    "dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)\n",
    "iterator = iter(dist_dataset)\n",
    "for _ in range(4):\n",
    "  result = mirrored_strategy.run(lambda x: x, args=(next(iterator),))\n",
    "  print(result)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "input.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
